{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch Image Recognition Tutorial\n",
    "\n",
    "\n",
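    "Using the data from the TinyMind *Chinese Calligraphy Recognition* (汉字书法识别) competition as an example, this tutorial walks through the whole workflow of training an image classification model with PyTorch.\n",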
    "\n",
    "For the data, see:\n",
    "https://www.tinymind.cn/competitions/41#property_23\n",
    "\n",
    "Or download it here.\n",
    "Free practice competition data:\n",
    "Training set: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg (password: aujd)\n",
    "\n",
    "Test set: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw (password: 4y9k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Exploration\n",
    "Please refer to the official data description."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
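    "# Data Processing\n",
    "\n",
    "In this competition only the train set has reliable labels, so only the train data is used here. In a real application, the data from both the stage 1 and stage 2 leaderboards would also need to be used."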
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
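    "## Downloading the Data\n",
    "\n",
    "After downloading, extract the archive to get a train folder containing 100 subfolders. The name of each subfolder is the Chinese character that serves as its label, and this directory layout is common in classification tasks. To check that the dataset matches the competition's description, count the files in each subfolder by running the following command inside the train folder:\n",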
    "```sh\n",
    "for l in $(ls); do echo $l $(ls $l|wc -l); done\n",
    "```"
   ]
  },
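  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same check can also be sketched in Python, for example where no shell is available. This is a minimal sketch: `count_files_per_class` is a hypothetical helper name, and the path `worddata/train` is assumed from the code later in this notebook.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def count_files_per_class(root):\n",
    "    # Map each class folder name (one Chinese character) to its number of files\n",
    "    counts = {}\n",
    "    for name in sorted(os.listdir(root)):\n",
    "        folder = os.path.join(root, name)\n",
    "        if os.path.isdir(folder):\n",
    "            counts[name] = len(os.listdir(folder))\n",
    "    return counts\n",
    "\n",
    "\n",
    "counts = count_files_per_class('worddata/train') if os.path.isdir('worddata/train') else {}\n",
    "if counts:\n",
    "    # Number of classes, then the smallest and largest folder sizes\n",
    "    print(len(counts), min(counts.values()), max(counts.values()))\n",
    "```"
   ]
  },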
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
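    "## Splitting the Dataset\n",
    "\n",
    "Because only the train set is used here, we have to split it ourselves to get data for validating the model during training. In other words, we need to build a validation set.\n",
    "> Generally, train is used to fit the model, validation to evaluate it and tune hyperparameters, and test for the final evaluation. A model's performance usually means its metrics on the test set. In real projects, however, there is usually only a train set and no reliable test set. Part of the train set is therefore held out as validation, and the model's performance on validation is taken as its final performance.\n",
    "\n",
    "> In general, we do not strictly distinguish between validation and test.\n",
    "\n",
    "Here, 50 randomly chosen files are moved from each class folder into validation. Run the script from the directory that contains train (worddata in the code below).\n",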
    "\n",
    "```sh\n",
    "export train=train\n",
    "export val=validation\n",
    "\n",
    "for d in $(ls $train); do\n",
    "    mkdir -p $val/$d/\n",
    "    for f in $(ls $train/$d | shuf | head -n 50 ); do\n",
    "        mv $train/$d/$f $val/$d/;\n",
    "    done;\n",
    "done\n",
    "```\n",
    "\n",
    "> Note that the validation set takes part in training only indirectly, through hyperparameter tuning, so some data is wasted."
   ]
  },
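  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Python version of the same split could look like the sketch below. Unlike `shuf`, it uses a fixed random seed, so the split can be reproduced. The helper name `split_validation`, the seed value and the paths are illustrative assumptions. Like the shell loop, it moves files, so run it only once.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "\n",
    "\n",
    "def split_validation(train_dir, val_dir, n_val=50, seed=0):\n",
    "    # Move n_val randomly chosen files per class folder from train_dir to val_dir\n",
    "    rng = random.Random(seed)  # fixed seed -> reproducible split\n",
    "    for cls in sorted(os.listdir(train_dir)):\n",
    "        src = os.path.join(train_dir, cls)\n",
    "        if not os.path.isdir(src):\n",
    "            continue\n",
    "        dst = os.path.join(val_dir, cls)\n",
    "        os.makedirs(dst, exist_ok=True)\n",
    "        # Sort first so the result does not depend on the os.listdir order\n",
    "        files = sorted(os.listdir(src))\n",
    "        for name in rng.sample(files, min(n_val, len(files))):\n",
    "            shutil.move(os.path.join(src, name), os.path.join(dst, name))\n",
    "\n",
    "\n",
    "# Moves files, so run only once:\n",
    "# split_validation('worddata/train', 'worddata/validation')\n",
    "```"
   ]
  },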
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Code: Data\n",
    "First, import PyTorch and check its version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.4.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision as tv\n",
    "\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
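    "Inside the model everything is numbers, which mean nothing on their own, so we have to give them meaning. Here the 100 Chinese characters serve as the textual labels for the model's numeric outputs.\n",
    "\n",
    "Training is usually an iterative process that is run many times, so a stable label order is essential (the order returned by `os.listdir` is not guaranteed). The code below therefore writes a label file on the first run. On later runs it detects this file and reads the labels from it directly."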
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "if os.path.exists(\"labels.txt\"):\n",
    "    with open(\"labels.txt\") as inf:\n",
    "        classes = [l.strip() for l in inf]\n",
    "else:\n",
    "    classes = os.listdir(\"worddata/train/\")\n",
    "    with open(\"labels.txt\", \"w\") as of:\n",
    "        of.write(\"\\r\\n\".join(classes))\n",
    "\n",
    "class_idx = {v: k for k, v in enumerate(classes)}\n",
    "idx_class = dict(enumerate(classes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
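    "PyTorch's ImageFolder organizes classes its own way: it sorts the class folder names and numbers them in that order. We want to use our own label order instead, so the indices need to be converted."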
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "pth_classes = classes[:]\n",
    "pth_classes.sort()\n",
    "pth_classes_to_idx = {v: k for k, v in enumerate(pth_classes)}\n",
    "\n",
    "\n",
    "def target_transform(pth_idx):\n",
    "    # ImageFolder index (sorted folder order) -> index in labels.txt\n",
    "    return class_idx[pth_classes[pth_idx]]"
   ]
  },
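  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what `target_transform` does, here is a toy run of the same mapping with three placeholder characters. They are hypothetical and do not reflect the competition's actual label order. `ImageFolder` numbers the class folders in sorted order, whereas `labels.txt` keeps the order of the first `os.listdir` call, and the function translates the former into the latter. The `toy_` prefix keeps the names from clashing with the notebook's variables.\n",
    "\n",
    "```python\n",
    "# Hypothetical label order, as it might appear in labels.txt\n",
    "toy_classes = ['水', '寒', '山']\n",
    "toy_class_idx = {v: k for k, v in enumerate(toy_classes)}\n",
    "\n",
    "# ImageFolder assigns indices after sorting the folder names\n",
    "toy_pth_classes = sorted(toy_classes)\n",
    "\n",
    "\n",
    "def toy_target_transform(pth_idx):\n",
    "    # ImageFolder index -> index in labels.txt\n",
    "    return toy_class_idx[toy_pth_classes[pth_idx]]\n",
    "\n",
    "\n",
    "for pth_idx, name in enumerate(toy_pth_classes):\n",
    "    print(pth_idx, name, '->', toy_target_transform(pth_idx))\n",
    "```\n",
    "\n",
    "Here ImageFolder index 0 (寒, first in sorted order) becomes label 1, its position in the toy label list."
   ]
  },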
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
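    "PyTorch (torchvision) provides an API for reading training data directly from a directory tree. Here `torchvision.datasets.ImageFolder` is used, together with `torch.utils.data.DataLoader`.\n",
    "\n",
    "Two datasets are used, one for train and one for validation.\n",
    "\n",
    "Note that the image data has pixel values in [0, 255]. PyTorch loads images with Pillow, whose data layout is (H, W, C), while PyTorch trains on data in (C, H, W) layout, so the data must be converted. `ToTensor` handles this conversion and also scales the values to floats in [0, 1]. In addition, the train dataset applies some preprocessing (rotation, brightness) for data augmentation and is shuffled, whereas validation needs neither.\n",
    "\n",
    "> A few points to note. If RandomRotation is used with `expand=True` (the commented-out line in the code below), every output image has a different size, so Resize must come after it. I did not find a way to make PyTorch read images directly as grayscale, and color carries no meaning for Chinese characters, so Grayscale is used to convert the images. ToTensor changes the data layout (and the transforms before it operate on PIL images), so it must come last."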
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing import cpu_count\n",
    "\n",
    "transform_train = tv.transforms.Compose(\n",
    "    [\n",
    "        # tv.transforms.RandomRotation((-15, 15), expand=True),\n",
    "        tv.transforms.RandomRotation((-15, 15)),\n",
    "        tv.transforms.Resize((128, 128)),\n",
    "        tv.transforms.ColorJitter(brightness=0.5),\n",
    "        tv.transforms.Grayscale(),\n",
    "        tv.transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "transform_val = tv.transforms.Compose(\n",
    "    [\n",
    "        tv.transforms.Resize((128, 128)),\n",
    "        tv.transforms.Grayscale(),\n",
    "        tv.transforms.ToTensor(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "img_gen_train = tv.datasets.ImageFolder(\n",
    "    \"worddata/train/\", transform=transform_train, target_transform=target_transform\n",
    ")\n",
    "\n",
    "\n",
    "img_gen_val = tv.datasets.ImageFolder(\n",
    "    \"worddata/validation/\", transform=transform_val, target_transform=target_transform\n",
    ")\n",
    "\n",
    "batch_size = 32\n",
    "\n",
    "img_train = torch.utils.data.DataLoader(\n",
    "    img_gen_train, batch_size=batch_size, shuffle=True, num_workers=cpu_count()\n",
    ")\n",
    "img_val = torch.utils.data.DataLoader(\n",
    "    img_gen_val, batch_size=batch_size, num_workers=cpu_count()\n",
    ")"
   ]
  },
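  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The note above says there seems to be no direct way to load images as grayscale. One possible alternative uses the `loader` argument of `torchvision.datasets.ImageFolder`. This is sketched here as an assumption, not the approach this notebook uses. A custom loader (the hypothetical `gray_loader` below) opens each file in Pillow's single-channel `L` mode, so the `Grayscale` transform is no longer needed. `transform_val_gray` is likewise a hypothetical name.\n",
    "\n",
    "```python\n",
    "from PIL import Image\n",
    "import torchvision as tv\n",
    "\n",
    "\n",
    "def gray_loader(path):\n",
    "    # Open an image file and convert it to 8-bit single-channel ('L') mode\n",
    "    with open(path, 'rb') as f:\n",
    "        img = Image.open(f)\n",
    "        return img.convert('L')\n",
    "\n",
    "\n",
    "# Validation transforms without Grayscale: the loader already returns 'L' images\n",
    "transform_val_gray = tv.transforms.Compose(\n",
    "    [\n",
    "        tv.transforms.Resize((128, 128)),\n",
    "        tv.transforms.ToTensor(),  # 'L' image -> tensor of shape (1, 128, 128)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Usage, assuming the directory layout used above:\n",
    "# img_gen_val = tv.datasets.ImageFolder(\n",
    "#     'worddata/validation/', transform=transform_val_gray,\n",
    "#     target_transform=target_transform, loader=gray_loader,\n",
    "# )\n",
    "```"
   ]
  },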
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point both datasets are ready to use. Before actually training the model, let's look at how the dataset reads data and what the loaded data looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([32, 1, 128, 128]), torch.Size([32]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A DataLoader is an iterable, not an iterator: next(img_train) fails,\n",
    "# but next(iter(img_train)) returns the first batch\n",
    "imgs, labels = next(iter(img_train))\n",
    "imgs.shape, labels.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
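    "The data layout is (batch, channel, height, width). Because these are grayscale images, channel is 1.\n",
    "\n",
    "> Note that PyTorch and MXNet use a different default data layout (NCHW) from TensorFlow (NHWC), so the data is handled somewhat differently.\n",
    "\n",
    "Let's display an image to check that the data and the label match.\n"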
   ]
  },
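  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small sketch makes the layout difference concrete. `ToTensor` turns an (H, W, C) uint8 array (or PIL image) into a (C, H, W) float tensor scaled to [0, 1]. A batch in TensorFlow's default (N, H, W, C) layout can be converted to PyTorch's (N, C, H, W) with `permute`. The array sizes here are arbitrary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision as tv\n",
    "\n",
    "# One 128x96 RGB image in (H, W, C) layout, as Pillow/NumPy store it\n",
    "hwc = np.random.randint(0, 256, size=(128, 96, 3), dtype=np.uint8)\n",
    "chw = tv.transforms.ToTensor()(hwc)\n",
    "print(chw.shape, chw.dtype)  # channels first, float values in [0, 1]\n",
    "\n",
    "# A batch in TensorFlow's default NHWC layout -> PyTorch's NCHW layout\n",
    "nhwc = torch.rand(32, 128, 128, 1)\n",
    "nchw = nhwc.permute(0, 3, 1, 2).contiguous()\n",
    "print(nchw.shape)\n",
    "```"
   ]
  },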
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'寒'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD7CAYAAABqkiE2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nO2deYzd13XfP2f24SxcJQ5FSlwkUhQpWqY0lihLlmXLrhzViGEgMJwEqZK4EVokaZYWsVz/4RRIAacNkjhI61SIkziFa9lx3NpwU6eulxiGLdqkFi6mKFLiNiRnhpQ4w+EMh7Pd/vHe977fzDySb+bt8zsfgHjD31t+9933+517zrlnsRACjuOkl4ZqD8BxnOriQsBxUo4LAcdJOS4EHCfluBBwnJTjQsBxUk7ZhICZfcDMjprZcTN7tlzncRynOKwccQJm1gi8Brwf6AN+Avx8COGnJT+Z4zhF0VSmz30QOB5CeAPAzJ4HPgTkFQJmVrQkamjIKDXd3d3x7xthZvOO3UggJl9fqODUexYraG90znzPmdl1z3Wz8ecb643Gn2/+5rKQzypk3IVSqflOHi/09Qs5V3NzM+vXry94/Ddj//79F0MIt8w9Xi4hsB44k/h/H/BQ8gVm9gzwzNw3LnZS29vbAXjPe95DZ2dn3kEl39fU1BQ/V8enp6ev+/mNjY3x9XNfd70LVcJIr09+l5mZmVmvyfd5ei55zuR49Hp9VlNTE1NTUzccy9zvqeOaj8nJyXnfWZ8fQpg3thuh9yVfPz09PW9+k+NOfpe531Pnvp7QmDsfNyPfOMTceWxsbIzfId8caU41xqampjiO5PyJ5ubm+Nzc31bjWLt2LZ/+9KcL+i6FYGan8h0vlxC4KSGE54DnADZt2hQ++clP0tDQECdx7sQ1NDTEyUleMPoRVq5cCcCHP/zhyn0Jx1kClEsInAVuT/x/Q/ZYXtasWcOv/dqvlWkojuPciHLtDvwE2Gpmm82sBfgo8PUynctxnCIoiyYQQpgys98A/hFoBP4qhHC4HOdyHKc4yuYTCCH8A/AP5fp8x3FKg0cMOk7KcSHgOCnHhYDjpBwXAo6TclwIOE7KcSHgOCnHhYDjpBwXAo6TclwIOE7KcSHgOCnHhYDjpBwXAo6TclwIOE7KcSHgOCnHhYDjpJyq1Rh0aoOrV68CsH//fgBaWlp429veBkBbWxsA4+Pj8XHFihVVGKVTTlwTcJyU45pAynn55ZcBOHnyZDw2MTEBwB133AHAiy++CGTKau/cuROAW27JlK9fvXp1pYbqlAnXBBwn5ZSlDdlC6e3tDfv27av2MArmypUrAHR0dACL65JTTdS34cUXX2TZsmVArr/DK6+8Mq8xiv4/OTkZ+zx0dXUB0NnZGb//fffdB2SaZji1h5ntDyH0zj3u5sACGRoa4vjx40Cui8yOHTvi3/WAmrhs376d7u7uWc+1t7fT398P5EwEddyZnJyMjkQde/PNN6P5MDY2BsCTTz4ZP8upfdwccJyU45rAAuno6IhOMZkD9aQFJJmrBQDcdddd3HXXXQBs2rQJgBMnTgAwODjIyMgIMLvfn/4+d+4cAD/84Q8BeOKJJ8ozcKekuCbgOCnHNYEF0tzczMaNG6s9jIqwYcMGIOfoO3v2LMeOHQPgjTfeADLbhtKE9HjmTKYh9cmTJ6M24dQuixYCZnY78LfAWiAAz4UQPmNmq4AvAZuAk8BHQgiXih+qUy10c2/atInly5cDcO3aNQAGBgais1ARhnIanjp1yoVAHVCMOTAF/NsQwg5gD/DrZrYDeBb4dghhK/Dt7P8dx6lRFq0JhBDOA+ezf4+Y2RFgPfAh4PHsyz4PfA/4eFGjdGqGlStXAnDPPfcAs7cIW1pagFx8wVtvvRVjKjo7Oys9VKdASuIYNLNNwG5gL7A2KyAA+smYC/ne84yZ7TOzfRcuXCjFMBzH
WQRFOwbNrBP4e+C3QwiXk9FzIYRgZnlDEkMIzwHPQSZisNhxOJVFAUcrVqyIEYiKOlTQkDQDp7YpSgiYWTMZAfCFEMJXs4cHzGxdCOG8ma0DBosdpFN73HbbbUAmRuCVV14BcnECWgja2trqNoYiTSzaHLDML/054EgI4Y8TT30deDr799PA1xY/PMdxyk0xmsAjwC8BB83s5eyxfw98GviymX0MOAV8pLghOrWIVvj169dz6NAhAFpbW4GcGdDV1UUtJKg5N6aY3YEfANdLn/N4UcepEzxi0CmKlpYWenp6ADh/PrMppHTjO++8MwYQObWL5w44TspxTcApmq1btwIwOjoK5AqObNu2rWpjcgrHhYBTNKpA/Pjjj1d3IM6icHPAcVKOCwHHSTkuBBwn5bgQcJyU40LAcVKOCwHHSTkuBBwn5bgQcJyU48FCTk0xPT0N5KoZNzc3s2rVKiAXkdjc3BwboV6+fBmA5cuXx/cqd8EpDBcCTsnQTahqw6Ojo/FYMs1Yx1Sl+PLly7E60cDAAACvvvoqkClaovqEes2yZcvi5w0PDwOwatWqmN589913A8TEpubm5gULBqVA11ufycXg5oDjpBzXBJybotZjZhZX5b6+PgCOHj0aV3TVHdSKPTk5GVfnO+64A8iUINPz0ggaGhriiq7PUn+DxsZG3nzzTYBY1bihoSG+Tp8xNDQUTQMVrpW20NTUFFd0nbuxsTF2ZNZzMzMz8Xm1VlN15a6uLtasWQPkmrLkY3AwU03v1ltvve5rag3XBBwn5bgmUAdcvnw5b/PQUqIV9eLFi7EQyFz7PFk0VP0E3nrrrfhepRBrFV2xYgXr168Hcnb6zMxMtM9lszc1NXH27Nn4eZArZNrV1cWRI0eA3Aq/atWq+Bn6/ObmZoaGhgDiCi+7/uDBg3G8SRtf2odW/3zoMxsaGmKrdWlG09PT0UGp5370ox8BsHr1ah588EGg9qsuuybgOCnHNYEycvHiRSC3gmmLa/Xq1axbtw7I2cCjo6NxxVBDT71+bGws5uxr1WpubmZ8fBzIrUKtra1xFdajzt3c3Bztea2i4+PjcbVM2t1ahWXjaxWdmZmJdre6DDU2NsaGpZs3bwZyK3FPT0/UDsTb3/72vHMl2zsf+nyt3PIXFMq6devieLVzcfny5Vm7E5DRHPT3XD/ElStX4nz/5Cc/ia/XnGrOpCX09/fH30oNbJctWxY1Ov1mtYALgRLR398P5Bxmo6OjUcXVxaMbs6urK6rCumiSKv+lS5n+rarMs2rVKk6cOAHk9s87OjriDabPGh4ejjeubnTd5J2dndFRpgv85MmT8Ry6sXp6euLn6fN1ob/++uvx++rmaG5uZteuXUCmpmA5WejNL/KZUjcSOgAHDhwAMk1VIdOAVWaSTISmpqZ4o+s30xxMTEzEY/rNOjs742dobh944IGym3o3w80Bx0k5qdIErl69GltqS71eDHO3kU6cOBG78Eh17ujomOVUgtmONa3UWgW6urri67QFtXPnzvi+jo4OILcCT09P89BDDwE5dTOEMC+4JV/Qi1aohoaGqJ5rtc8XHHPvvfcCs80NaTcjIyNl1wCqwZYtW4BcBeX29vZ47dxyyy1ARhOQCSdTRVuhW7ZsiVrh/v3742uk+ek3aGtri79ztTQC1wQcJ+WkQhN46aWXADh79mzc0nnHO94B5JxfhXLixInouNN21tjYWHTiaeXVSgk5rUN26Jo1a3jb294GUHCvPjnd5ICamJiY53TLt4rnO6ZxvOc97yno3EJVhdOANKPk7ygNUNuS27Zti7+ffEK6vsyM22+/Hcj9/n19fZw8eRLI+YJOnToVezhqfnVtQM5JLO2jHH0cStGVuBHYB5wNIXzQzDYDzwOrgf3AL4UQJoo9TzHIMTMyMhKjvuQck+e5UA4cODBLFYaMWq2LQd7nW2+9Nd6kd911F5Bz4OVT2wtFzj09OuVB6r1+4/b2dnp7ewHYvn37
vNffyNGo6+Cee+6JAkS7NlevXo1OXzkjZUZ2dHTEHSY5cLUYlJJSmAO/BRxJ/P8PgT8JIdwFXAI+VoJzOI5TJoptTb4B+OfAfwR+N9up+L3AL2Rf8nng94HPFnOexSIpqv3grq6u6PBZ6EqqqLjR0VHGxsaAnOo/MzMTzQpJ/Z07d0Yn0VzSkJlW70ijk7Nu48aNeTWAhaLPSzoBFTOi6+nYsWNAJi9DGqscztPT01GzLBXFagJ/CvweMJP9/2pgKISgOMw+YH2+N5rZM2a2z8z2STVyHKfyLFoTMLMPAoMhhP1m9vhC3x9CeA54DqC3t7cs/au1badVfO3atXEbbqHbMZLSLS0tcZXQsWXLlsVtOn3u9bSAWuLQoUNx+8+ZjXxHcp7KUVgO5EwU8i80NTVFn4S02WPHjpVcEyjGHHgE+FkzewpoA7qBzwArzKwpqw1sAM4WP0zHccrFooVACOETwCcAsprAvwsh/KKZ/R3wc2R2CJ4GvlaCcS4K2VNasaempqIvYKF2uWz+Xbt2Rc+uAj527drFpk2bSjHkiqJtTef6zN2GrQQKJNu6dWvchlb48tDQUNxNKNVOQTniBD4OPG9mfwC8BHyuDOe4IUoCSUbXQWbPt9gf9e67745psfVe004Rh05tsmLFCvbs2RP/how58PLLLwO5665Y07MkQiCE8D3ge9m/3wAeLMXnOo5TfpZkxKA0AaVrSiMotXOnXjUAp35Q5qS0to6ODn7wgx8A8L3vfQ+Ap556KjoyF4PnDjhOylmSmoCyt2T/K9AiGQfuOPXI9PR01HCVy3Cj8miFsOSEwJUrV6InVbsCQrH7aUAmkFJhFccAuYi0qakp7rnnnsoPLkvyIk5WEoby7svXM2+88UY0Q5WvMjg4WNS17eaA46ScJacJdHZ2zquNpxXn3LlzdREhp/gDbQstJtdAlYJfe+01IBOVptVVjtMzZ87EjDUVxwghxPMuNjJN4z98+HCMy5AK29HRETP0lCabLPaiVFmtbLt373atAOL8NDY2RjNXv+O5c+dmpR8vFNcEHCflLDlNAHLVbrXKyXbS6lirKJ9cdfbVxebRRx8tuPiIUEELbR0lq9s+8MADQCaXQk5UFbYYGRmJJcS0Uksj6OrqilqCbPhLly7FrEo5XvW+U6dORRtf2llbW1uMVEzW49dKp8/Sa4aHh3nssceA0pTfUhRpd3f3gue0mkijeve73x3nWUFDFy5cKCpwbUkKgblRgboQdWNUm6NHj8ZwUKl04+PjUY1WPcFkY43FcqPS1hs2bIj1DCUwx8fHo7Cc2/Sjubl53ucNDw/Hi1KO2KTqqotX33NqaipesDLX2tvb428k00ef0d/fz4svvgjkqirNzMzMG8dtt902b56OHz8OZISN5lvmkZnFSj7XK4NeiyxfvjzGDijxqL+/PwoECfiF4OaA46ScJakJSI1W3X9tg5WiKEQxHDp0CMio/VK5tQKOj49HTeW+++4DKtvUMll9WU45lbRSk83h4eFYqk3mAOQKcMipp9W8oaEhfr+5TVEgp6W0tLTEz0s2V4GMGqy6fNr6TVY9ltmwefPm+PtqjKdPnwYyc6vVU1rL5cuX55UQa21tjZqLKgAvtAZlJZjbHq61tTVqa4tpqe6agOOknNoTcyVAdfBlV6pDTqXR6ibbWu2rmpqa4tiSvQYUH17tjjRCK3Uy0Ejl2cSFCxdi8JHGrdVzeHg4rux6bmZmJjr/kpGcmgdpJMntTD2nVW5sbCw+r5X71VdfnaUpQG5uW1tb46qvc3d2dkZtQuMPIcSx67dLajD333//rM+vFqqeLSYnJ+M8L2Y72TUBx0k5S1ITkFSslgYAGftVW35ahZL2sOxneXj37NlTMxrAjZgbuNPZ2Rm99oWskFeuXIk27eHDh4HMzoQKZOhR52lqamLv3r1AztZXzwjIbYm99tprUbPQKi5fRU9PTxybis+2trbO0iz0Pv1G+u2EmcVzqeV4tZC2
JF/N4OBg3OXRtb8QX8aSEwJXrlyJe9/VdAS2tLREZ9TcbsArV66Mx7QNWOs97G/EQtTjzs7O+N3luLt8+XJ0QubrA6GbXtGe+V5z1113RaGi7UA5F3fv3j2vq3Oyft8//dM/ARlHsgSxzDXR2NgYK/rouqqW0Nb55fju6uqKzUm0/boQIeDmgOOknCWnCXR2dsbyX9Vk3bp1sZ58JVHA0c1ab9cCyZRYOfHyka9Wf77PUoegG5FcIRVUphJeJ06ciCaFtAmZDx0dHXF7UZV/q6UJSKvR48DAQNQKdEzmQSG4JuA4KWfJaQKQ7g4/CkhSwFEtOxu1Kk9NTUVnoWzbSqL8imSJLtVhUFvxW265JYZzL7R/ZblI5gkoWGgxFaSXpBBIM1IHdVHUshBIRrdpr17JSuXovrsQZMq9//3vB6ofG5Dk9ddfB3KO1c7OzqLyS9wccJyU45rAEkPqtLZJaxll8Q0NDUVzQOpstTUBUUsagFB2ZDL6UU7LZA5Iobgm4DgpxzWBJYacVwoaqWW0ar3jHe+IUYHyZcxt0unkkENVAWZXr16NGou2NhfivCxKCJjZCuAvgXuBAPwqcBT4ErAJOAl8JIRwqZjzOIWjfW2FzJab6enpopuwrFixIlaDUtqwHIS1mMpbTcbHx+MNn0xuUpzFYoR/sebAZ4BvhhC2A/cBR4BngW+HELYC387+33GcGmXRYtbMlgOPAb8MEEKYACbM7EPA49mXfZ5Mj8KPFzNIp3AUPy9NQBFwxVTsTZYP00qjOPrR0dEYnSb1Xqv3wMBAXK20hdXR0RGj2/Rcd3d3jMJbTAJMmmhra4vzrVJpMzMzcS4Xo5UVowlsBi4Af21mL5nZX5pZB7A2hHA++5p+IK9xYmbPmNk+M9snO9BxnMpTjLhtAu4HfjOEsNfMPsMc1T+EEMws5HtzCOE54DmA3t7evK9ZCFrxXnjhhbhyvetd7wJyTR2XAjdb2eUQOnr0KJBLv71eXXptN2mlvnz5Mv39/UAumEf2emdnZ1ypk2XGlI2nWHytRqdOnYq/hRyWK1eunNc6vrGxMb5H59QW54oVK+I5b7b9Je1nbuDMxMRE/E7KV6hGZGKp2LRpE5BLlT98+HCct8VoUMVoAn1AXwhhb/b/XyEjFAbMbB1A9nGwiHM4jlNmFq0JhBD6zeyMmd0dQjgKPAH8NPvvaeDT2cevlWSkCbRKnDx5Mq4wCja5cOFCDKLYt28fkKnVDvVnZ2rVHxsbi4VJ9D137tyZN1NQtqHeqxj462kCssVlpzc2Nsatp1WrVgG58mLd3d3zOgUNDQ3FFVqrqzSvnp6e6ENQPkdbW1tcsXXu6enpqG2omIe2DBsaGqImoNJmGzdujJqD+gi88cYb8bpQjQa9ZmxsLG496hpYvXp1zDqs1x6VqrMwPT0df78bZWNej2Lvit8EvmBmLcAbwK+Q0S6+bGYfA04BHynyHPPQD3rkyJG4XaILq6mpKaqluihU46+alYYKQZViNF4Js2XLlsWKO/rhrxfJpptTN7JU7vHx8bxReAudE523kHZuC+nzIIEggaKtwunp6Xm1GhU7n3zf5ORk/M4yN3TDr1u3jjvuuAPIOTknJydj2nW9CgFx5513xtwLzcdCKEoIhBBeBvIlcT9RzOc6jlM56ks/zrJ7924gEx0lB1hy5ZP6KNVYTqxqawLK+nr99dej5pJs0yXVWSuZ1PFHHnmkYFNGkXbJir+QMSOSVYNriddeey3OjVYyObrGx8fnpfDeeuut8feWttTc3Bz7NEhb0fvquXRbIVy8eDFe65V2DDqOswSoS01AdHR0zLONk44k9R/Qlkq1kLNLWktbW1u08WWnj4+Px5VOK5ceFyLd5biTYzDpwKs1TUDbdj/+8Y9n+T8g5+O57bbbYgjx3DlLK9IYVfDk4MGD81rAL4S6FgKdnZ3xS+uib21tjReJGk1WOxlFaqkq6m7fvr1sKapyommvXOaGTIxaQt2Lr127
FudDZoCEwMzMTHQwpv3mF2pio0Vleno6XmOLiYlxc8BxUk5dawL33ntvjHjTfnFra2tUq6utAQiZJZVEGoBW1EuXLkUtoVYcZck97bmOUu3xj46OFp2luNTQVrJMvz179sSmu4spM+aagOOknLrWBCAXiSYJ2NbWFiPG0oiiCGVja7W4du1atMGroZnkQ3YszNdO9Ltu27ZtUc6upYzmQxrv3Xff7YVGHcdZPHWvCShARPZuQ0PDrPrxaUMBRtpWHBoaAjJ5+gq3rhVNIBniOjfmXSubdlScHNKSiskcTFL3QkCx9OqM29bWlmohILVaTlFtDS5fvrziN/+BAwdiLkAyqUfx+xLca9asiVWGFfmm19dy34RqoaIi2iI8d+7cgvI05uLmgOOknLrXBKQK1Vo0XLVRBx2lmF67di1GUpYbOSNPnToVS2ApU+/hhx+OQT/SVrZt2xbNAam4hdbP13f67ne/G/+W40yRhgtpzlkPJE1fWFyvgSSuCThOyql7TcDJj1YHbcMNDQ3F+grlbpmuwJ8QQnT+qR6+VnrIrdTFoECit956K66Qyj+ohRb15UAagDSu8fHxWENjMbgQWKKo7p8ulMbGxlioQ863jRs3xh0DRVyuXLkyqulS2xWPXuhetIp/DA4OxpteztqpqalZ3YghY9JJWCw0JkDmTlNTUxQ+p06dAnK7D+vXr4/f6XoVluoJCXYJvWIbzbg54DgpxzWBJYbiypW+LLOgp6cndieSs+78+fMx+1Kr8uDg4LxMRJkPW7ZsiSuqXjMyMhJXIhUwUU+Ctra2GLmoVOGXX345jlWfv379+riVqT3wZHZosiQYZLQbaRha7bds2RJrHErjkalw8eLFqB1oC3X79u03m8qScejQIaCwkmyFoDqT0qiK3RJ3TcBxUs6S1ARK0XWn3GhV08pXqkrIL7zwwqzPV02FDRs2xNVT5dampqbiMc1ZY2NjXNGlOaj018DAQBynHFGtra1Ri5B2IG1h/fr1Mb5dPopk7ruiPSG3hSitIpk5qJVvbn8DmO+30HdNMjk5yXe+8x0gF2AzPDwci49qJS0m/j4f0speeumleEzO0MXWRrh06VKsIC0HaOojBufS19dXswlEUnlPnz4dHVqioaEhen31mCzTLYGWdKZJtdVNMjw8HCPsHnzwQWB2JV19XqGRgxIMGuvw8HBU73UDJW8c3dw6T5J85dHzkS9tWIJkIZ12kzQ3N/Pkk08CuWrDJ0+ejAJB6vrMzEwcu14nobRq1aooXAqdPzlgk6aQKiXLQalqSYUyOTkZoytLVTHLzQHHSTlLRhOQhD1x4kTNRohpld64cWNcHeTYamhoiKuOVmA919TUFDUBrVCDg4NxVVbceG9vb0lr6OucS2m/XWr49u3bo3NQ5s/Bgwej1pP8XfR/XWMyR5LmTD40f9KCrl69GhO6dM6FYmZR/S82UlC4JuA4KafuNQE5tn76058CmZWy1h2DXV1d0Wa/Efoek5OT82zqF154Ia5WDz/8MFA7ZcPqDTkVH330UV555RUgt82oLkUTExPxWjt48CAA733ve/P6P4SuP2kfY2NjUduT43ahtLa2Rj9MqZzJRWkCZvY7ZnbYzA6Z2RfNrM3MNpvZXjM7bmZfyrYocxynRlm0KDGz9cC/AXaEEK6a2ZeBjwJPAX8SQnjezP4C+Bjw2ZKMNg+y0xT22tHREfuyqVNRvXIjTWbPnj0VHEl62LlzJ5DTquQjuHjxYvTHnDt3DsiEJxfiodfKPTMzE3c/FM69ULq7u+MuTKmyQovVJ5qAdjObBJYB54H3Ar+Qff7zwO9TRiEgtU2OtpmZmdiUQZN03333Aa4uOzdHKraq9+pxYGAg1vtX/MSrr74azTTt2eeL+5AJEEKIZoASqq5cubJgs1XnWuyW6VwWbQ6EEM4CfwScJnPzDwP7gaEQgkRUH5B3097MnjGzfWa2T0ksjuNUnmLMgZXAh4DNwBDwd8AHCn1/COE5
4DmA3t7ecJOXXxdJWTlfkm3IlM2mbZ4HHnhgsadxUs7atWtjKTtdV2fOnImRiCpvt2XLFoCYxwA5J2CypqICvA4ePBgdu4WivIrFdBvKRzGOwfcBJ0IIF0IIk8BXgUeAFWYm4bIBOHu9D3Acp/oU4xM4Dewxs2XAVeAJYB/wXeDngOeBp4GvFTvIGyFHi1b/tra2efnqCtBwnGKY29nq8OHDsR6DAneSGsBcGhoaYiixthtfe+21GOBVaAixto6Vm1BscNyihUAIYa+ZfQV4EZgCXiKj3v9v4Hkz+4Pssc8VNcKboOgteVuTE60U2K1bt5ZzCM51uHTpUsE5A/WEFhk5nG+GvPnd3d1xYVIiUXt7+4KjPDWnpYqDKWp3IITwKeBTcw6/Adw8EsZxnJqg7iMGJVG1d3vx4sW4haKYemW8OZVBBTzOnTsXU5nT3EpMmunU1FT8W6bFxo0bF7x1rZyHUrVq99wBx0k5da8JJAtYQKbwhLZcaqXdVj3R19cXt11lcxba4l32rmLr33zzzVgI5J3vfOesz0wT8kktW7Ysbg3KX7WYALZSaQCi7oWACnUobNjMXP1fBCqscebMmehklQr/vve9D8hf8COJil3oQr969WoUDIrifPe7313ikdc+WozuvPNOjh8/DtSWMHRzwHFSTt1rAqqqK8fg2rVrS14rbimjVV+lti5duhTNAc2jqhMrjh5yFYXNLEbEaf9adHd3x7qDSupqb2+PSTcqigI5U6JQ06NeKUXDlVLjmoDjpJy61wSSVXIhtx3jFIZWY9moTU1NcQ7lZ/nxj38MZOZajixpAsnSZwqiSUZxSpuQxnHo0KFYAEavb2lpib/jrl27gFwU3Lp162KmqKoeP/zwwwvS9iYmJqJPYtu2bcDS1zgWgmsCjpNy6l4TePTRRwFiNtfNij86+XnkkUeA2d19VGpLOzD9/f3RXyDbduvWrbNse8hldh47dixqDsqy6+7ujjHvKtLR2NgYt73UE0E5+6+88krM0V+stnfmzJkYwKQAMidH3QsBqaLaeipVemXa0Dwmt66Uep1U77Xvf6Nt2Pvvvx/IbInp5pb63t3dHVV/PU5PT8fXqaafKkYNDw/HaseKPlwMEmQ/+MEPgEw/BtdWSZIAAA6XSURBVMUupL3YjJsDjpNy6l4TEK4BlJ58de0LyQpUZmeyXVihz6slmMyCd77znUWr8HfeeWfconz11VeBTHkvaTjvete7ivr8esc1AcdJOUtGE3CWFtrKK1UxTW05njlzBsj4ARQgJT/IPffcU/K4/HrAhUANodZU8prL6z4yMhI70Uo1VgLKUmWhjTpvxr333gvk8hsOHz4cdzFURfjw4cNxt6lUzT7rATcHHCfluCZQI4QQ+OEPfwjkMu42btwIZDQC5UZo+22pawLl4qGHHgIyGY5yFqoa9ejoaEyD1lalWojfLIOynnFNwHFSjmsCNcKhQ4diQMuqVauA3Opz9epVduzYAeTq2jvFsWvXrpgbIX9Le3t7rB6s9mOKTtyxY8eSLZHmQqDKyDn14osvxpRceavlxOru7k51ZZ5ysHr16hgR+c1vfhPIREQqdkG/i+IVxsbGlmw8gZsDjpNyXBOoMnIGQi69VRqA9sgfeugh1wDKgGIHtB145syZ2CpM8QLath0dHY2doJda0RrXBBwn5bgmUCW0BaWVpqenJ25fKZVX/RPmpuo6pUFFTdQ6bGBgIGoCSl9WhuHMzEwsoFps269a46aagJn9lZkNmtmhxLFVZvYtMzuWfVyZPW5m9mdmdtzMDpjZ/eUcvOM4xVOIJvA3wJ8Df5s49izw7RDCp83s2ez/Pw78DLA1++8h4LPZR2cOCkqRF/qxxx6ju7u7mkO6LioCqpXzeqioqAKbhoeHC+7XV03UK3B6ejoGDkkbU1m0EELMO1CPC2kL9c5NhUAI4ftmtmnO4Q8Bj2f//jzwPTJC4EPA34bMDL5gZivMbF0I4XypBlzP6MI6
evQob731FpCrSV9LAkC1/LRXrv307u7uaKLs3r0byBXrgFyko77byMhI3PZU66xadHBqTC0tLTE2Q8JZQsHMYs8AOQbf/va331Qw1gOLdQyuTdzY/YBSvdYDZxKv68sem4eZPWNm+8xsnwI0HMepPEWLsRBCMLOwiPc9R6aVOb29vQt+f61z6tSpWNdOq4lWxdOnT0cNQNV1a4WhoSH27t0L5DoJabVrbm6O6rEcmiMjI/F5rZ7Ka2hra4tlwmQi9PT0xHqDev3KlSurqgnJ0bdz587YY0ELk5yybW1tjI2NAbnSZyGEmFtQz87bxWoCA2a2DiD7OJg9fhZIloHZkD3mOE6NslhN4OvA08Cns49fSxz/DTN7noxDcDgt/gCtkKqpf+TIkbjdpBVP9vTExERN1L0PITA0NATkxvj666/HcWt1liYzPT0963WQWQG1svf29gK5EmSXLl2KhTvUO+DQoUPR1yD728zi9mg1q0Xv3r07ljf70Y9+BOQ0HmlxQJyzAwcORH+J6h+oKGo9cVMhYGZfJOMEXGNmfcCnyNz8XzazjwGngI9kX/4PwFPAcWAM+JUyjLkmUXluCYHx8fFYo0/OMyWjmFksu51s7VUp1Hz05MmTsYCJnJYjIyNxnBJaUoMnJiai2iuht2PHDnbu3AnMrz+4cuXKGGWnm6u/vz964/V49OjR6FSUA7HURUUKRUVbFMeh31NRnJDbFWhubo5mjpyhEqDr16+P8Qe1TiG7Az9/naeeyPPaAPx6sYNyHKdy1P/+Rg3Q19cXVwSpy1NTU3H1kDqtFTaEEOv3a2+9krXvZYqMjIzE8UrFHRoamhUlB7Oj56TCy1TYunVrQRWI1awkX0POtWvXxnRezYsKftx+++1Vqfsnh5+0oePHj9PX1wfknIATExNRg5IGoHGPjIzEeaz1StieO+A4Kcc1gRLQ398fbXytEsuWLZu18kPOjm5vb4/+gmp0v9EWXrJEmVbgvXv3xu8ipC00NTXF6EGVPitFNeDly5fPWy0VnTc4OHjDbkflRppLW1tb1Oz0ODMzM8tpCrMdiTr2+OOPV3LIC8Y1AcdJOa4JlIDGxsa8TTK1Smi1V1DNbbfdVpVdgRuhijqPPfZYDJjR9pc85W1tbdGHUe4yZ7XWODSZObhv3z4gE1A0tx27roOJiYlYlUhbqOqlUGu4ECgBK1asiBeJ9pOHh4djrUDtHSvxpBbj50VLS0us0S91Vo1DkxV3a/k7lAv9xoom7O/vnxU/ADnB39DQEJ87fPgwkMnFkJNVzVXztXqrNG4OOE7KcU2gBGzYsCEGhmirra+vL6qB1XRsFYNW/lrf4qo00vCam5uj+j/XHJyamorzd/HiRSBjXslJrMrF7e3tcYu1WiaQawKOk3JSpQkkY98L3ZpTJqCker7SUs3NzTHHXLZyT0/Pksg1d+aj7dGtW7fG4KC5HYoaGxujT0Crv5lFn4BqNoyOjsbrROHXyjBtaGjgyJEjQC70vLu7OxZGLVVYciquUt3IBw8ejPvcilHX41yk1it2fOvWrQs6pwuApc/9998fb07lDuiGn56enrdzYGbx+lM8QQghLkwHDhwAco7E1tbWGIkoBgYG4oJTKiHg5oDjpJwlt1ydPn06Zr1JNVc66/nz56M6Jmnb1NQUHTJS0QYGBmK8up6TiuY4orOzk56eHiCzXQi5fgXXrl2LzsJk7og0Aa3mra2t8ZrUc4osTWYuyjm7fv36ksdouCbgOCmnrjWBK1euRJtd0vPEiRNRsupRwS7Nzc3RZpPN//3vfz+2+1Z2XXd3d9QACsmQc9KLnHS6DkU+x/O1a9ficWUnTk9PRwdiPs1Br5ND+oEHHih5Y9S6FgLXrl2LiSa6qVtaWmKhDE2uVLTGxsb4nNSroaGhmMaq8ti1FrLq1C4qlqJrRslXExMT84qnXL58OV5/cig2NjZGx6EciRICa9eujVWdZXaUAzcHHCfl1LUmsHr1ah588EEg
V812ZGQkOvjkfJFGIIchECXypk2bYhz3Umkm4VSeRx99FMg5CAcGBmLhFWkJb775ZtQ6dS2OjY3NanACOc1h9+7dFSlR5pqA46Qck/SpJr29vUHpmcUyPj4et//kL1D66/Lly6NWoOIf27dvdx+AU1WkCUg7lY+g1JmaZrY/hNA797hrAo6TcuraJ5CPtrY2HnjgASBTDhtyefFpzIF3ah9ppdXqYrTkhECSUu+nOs5SxM0Bx0k5LgTycPXq1Vlpx46zlLmpEDCzvzKzQTM7lDj2n83sVTM7YGb/08xWJJ77hJkdN7OjZvZkuQbuOE5pKEQT+BvgA3OOfQu4N4TwNuA14BMAZrYD+CiwM/ue/2pmjdQJly5d4tKlS5w9e5bz589z/vx5pqamYnaX4yxFCulF+H0z2zTn2P9N/PcF4Oeyf38IeD6EcA04YWbHgQeBH5VktGVGyUKeNOSkiVL4BH4V+D/Zv9cDZxLP9WWPzcPMnjGzfWa2TyWcHcepPEUJATP7JDAFfGGh7w0hPBdC6A0h9NZLC2fHWYosOk7AzH4Z+CDwRMjFHp8FkjG4G7LHHMepURalCZjZB4DfA342hDCWeOrrwEfNrNXMNgNbgR8XP0zHccrFTTUBM/si8Diwxsz6gE+R2Q1oBb6VTb99IYTwr0IIh83sy8BPyZgJvx5C8M12x6lhllwWoeM4+fEsQsdx8uJCwHFSjgsBx0k5LgQcJ+W4EHCclLOki4o4Tr3z7LPPApnqxSLZ5RiYlfI+d7evkArargk4TsqpiTgBM7sAjAIXqz0WYA0+jiQ+jtnU8zg2hhDmJerUhBAAMLN9+QIZfBw+Dh9Hecfh5oDjpBwXAo6TcmpJCDxX7QFk8XHMxscxmyU3jprxCTiOUx1qSRNwHKcKuBBwnJRTE0LAzD6Q7VNw3MyerdA5bzez75rZT83ssJn9Vvb4KjP7lpkdyz5WpPSwmTWa2Utm9o3s/zeb2d7snHzJzFoqMIYVZvaVbE+JI2b2cDXmw8x+J/ubHDKzL5pZW6Xm4zp9NvLOgWX4s+yYDpjZ/WUeR3n6fYQQqvoPaAReB7YALcArwI4KnHcdcH/27y4y/RN2AP8JeDZ7/FngDys0D78L/A/gG9n/fxn4aPbvvwD+dQXG8HngX2b/bgFWVHo+yFSnPgG0J+bhlys1H8BjwP3AocSxvHMAPEWm0rYBe4C9ZR7HPwOasn//YWIcO7L3TSuwOXs/NRZ8rnJfWAV82YeBf0z8/xPAJ6owjq8B7weOAuuyx9YBRytw7g3At4H3At/IXlQXEz/4rDkq0xiWZ28+m3O8ovNBrmz9KjK5Ld8AnqzkfACb5tx8eecA+G/Az+d7XTnGMee5DwNfyP49654B/hF4uNDz1II5UHCvgnKRba6yG9gLrA0hnM8+1Q+srcAQ/pRM4daZ7P9XA0MhBLU+qsScbAYuAH+dNUv+0sw6qPB8hBDOAn8EnAbOA8PAfio/H0muNwfVvHYX1e8jH7UgBKqKmXUCfw/8dgjhcvK5kBGrZd1DNbMPAoMhhP3lPE8BNJFRPz8bQthNJpdjln+mQvOxkkwnq83AbUAH89vgVY1KzMHNKKbfRz5qQQhUrVeBmTWTEQBfCCF8NXt4wMzWZZ9fBwyWeRiPAD9rZieB58mYBJ8BVpiZUr0rMSd9QF8IYW/2/18hIxQqPR/vA06EEC6EECaBr5KZo0rPR5LrzUHFr91Ev49fzAqkosdRC0LgJ8DWrPe3hUxD06+X+6SWSbT+HHAkhPDHiae+Djyd/ftpMr6CshFC+EQIYUMIYROZ7/6dEMIvAt8l1+OxEuPoB86Y2d3ZQ0+QKR1f0fkgYwbsMbNl2d9I46jofMzhenPwdeBfZHcJ9gDDCbOh5JSt30c5nTwLcIA8RcY7/zrwyQqd81Eyat0B4OXsv6fI2OPfBo4B/w9YVcF5eJzc7sCW7A95HPg7oLUC5387sC87J/8LWFmN+QD+A/AqcAj472S8
3hWZD+CLZHwRk2S0o49dbw7IOHD/S/a6PQj0lnkcx8nY/rpe/yLx+k9mx3EU+JmFnMvDhh0n5dSCOeA4ThVxIeA4KceFgOOkHBcCjpNyXAg4TspxIeA4KceFgOOknP8Pgrep39TusOgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.imshow(imgs[0, 0, :, :], cmap=\"gray\")\n",
    "classes[labels[0]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
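    "# Training Code: Building the Model\n",
    "\n",
    "PyTorch builds the computation graph dynamically as the code runs (define-by-run), so building a model is fairly simple. Here the model is defined as a class; simple models can also be built directly with Sequential.\n",
    "\n",
    "The complex model here is likewise built by stacking simple Sequential blocks.\n",
    "\n",
    "> The model built here is a VGG-16-style network (with BatchNorm and global average pooling added). For more details on VGG, see arXiv:1409.1556."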
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyModel, self).__init__()\n",
    "        # The model has two main parts: the feature extractor and the classifier\n",
    "\n",
    "        # Feature extraction layers\n",
    "        self.feature = torch.nn.Sequential()\n",
    "        self.feature.add_module(\"conv1\", self.conv(1, 64))\n",
    "        self.feature.add_module(\"conv2\", self.conv(64, 64, add_pooling=True))\n",
    "\n",
    "        self.feature.add_module(\"conv3\", self.conv(64, 128))\n",
    "        self.feature.add_module(\"conv4\", self.conv(128, 128, add_pooling=True))\n",
    "\n",
    "        self.feature.add_module(\"conv5\", self.conv(128, 256))\n",
    "        self.feature.add_module(\"conv6\", self.conv(256, 256))\n",
    "        self.feature.add_module(\"conv7\", self.conv(256, 256, add_pooling=True))\n",
    "\n",
    "        self.feature.add_module(\"conv8\", self.conv(256, 512))\n",
    "        self.feature.add_module(\"conv9\", self.conv(512, 512))\n",
    "        self.feature.add_module(\"conv10\", self.conv(512, 512, add_pooling=True))\n",
    "\n",
    "        self.feature.add_module(\"conv11\", self.conv(512, 512))\n",
    "        self.feature.add_module(\"conv12\", self.conv(512, 512))\n",
    "        self.feature.add_module(\"conv13\", self.conv(512, 512, add_pooling=True))\n",
    "\n",
    "        self.feature.add_module(\"avg\", torch.nn.AdaptiveAvgPool2d((1, 1)))\n",
    "        self.feature.add_module(\"flatten\", torch.nn.Flatten())\n",
    "\n",
    "        self.feature.add_module(\"linear1\", torch.nn.Linear(512, 4096))\n",
    "        self.feature.add_module(\"act_linear_1\", torch.nn.ReLU())\n",
    "        self.feature.add_module(\"bn_linear_1\", torch.nn.BatchNorm1d(4096))\n",
    "\n",
    "        self.feature.add_module(\"linear2\", torch.nn.Linear(4096, 4096))\n",
    "        self.feature.add_module(\"act_linear_2\", torch.nn.ReLU())\n",
    "        self.feature.add_module(\"bn_linear_2\", torch.nn.BatchNorm1d(4096))\n",
    "\n",
    "        self.feature.add_module(\"dropout\", torch.nn.Dropout())\n",
    "\n",
    "        # This simple layer is the classifier\n",
    "        self.pred = torch.nn.Linear(4096, 100)\n",
    "\n",
    "    def conv(self, in_channels, out_channels, add_pooling=False):\n",
    "        # The model is built from many repeated blocks;\n",
    "        # this helper factors out the repeated block to simplify construction\n",
    "        model = torch.nn.Sequential()\n",
    "        model.add_module(\n",
    "            \"conv\", torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)\n",
    "        )\n",
    "        model.add_module(\"act_conv\", torch.nn.ReLU())\n",
    "        model.add_module(\"bn_conv\", torch.nn.BatchNorm2d(out_channels))\n",
    "\n",
    "        if add_pooling:\n",
    "            model.add_module(\"pool\", torch.nn.MaxPool2d((2, 2)))\n",
    "        return model\n",
    "\n",
    "    def forward(self, x):\n",
    "        # forward defines how the submodules are connected in the computation\n",
    "\n",
    "        x = self.feature(x)\n",
    "        return self.pred(x)"
   ]
  },
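  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does `linear1` take exactly 512 inputs? The 3x3 convolutions use `padding=1` and keep the spatial size, and each of the five `MaxPool2d((2, 2))` layers halves it (rounding down). `AdaptiveAvgPool2d((1, 1))` then reduces whatever remains to 1x1, leaving 512 channels x 1 x 1 = 512 features. The sketch below checks this arithmetic against PyTorch for a 128x128 input; `spatial_size_after_pools` is a hypothetical helper name. A single channel stands in for the real channel counts, which do not affect the spatial size.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def spatial_size_after_pools(size, num_pools=5):\n",
    "    # 3x3 convs with padding=1 keep the size; each 2x2 max pool (stride 2) halves it\n",
    "    for _ in range(num_pools):\n",
    "        size //= 2\n",
    "    return size\n",
    "\n",
    "\n",
    "# Same pooling sequence as MyModel, applied to a dummy single-channel batch\n",
    "pools = torch.nn.Sequential(*[torch.nn.MaxPool2d((2, 2)) for _ in range(5)])\n",
    "y = pools(torch.zeros(2, 1, 128, 128))\n",
    "print(spatial_size_after_pools(128), tuple(y.shape))\n",
    "print(torch.nn.AdaptiveAvgPool2d((1, 1))(y).shape)\n",
    "```\n",
    "\n",
    "The same arithmetic shows that inputs smaller than 32x32 would shrink to zero before the last pooling layer."
   ]
  },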
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
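    "Note that the input and output sizes of every layer (e.g. the number of channels) must be specified explicitly here, which is less convenient than in TF and MXNet.\n",
    "\n",
    "Instantiate a model and take a look:"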
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MyModel(\n",
       "  (feature): Sequential(\n",
       "    (conv1): Sequential(\n",
       "      (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv2): Sequential(\n",
       "      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (conv3): Sequential(\n",
       "      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv4): Sequential(\n",
       "      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (conv5): Sequential(\n",
       "      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv6): Sequential(\n",
       "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv7): Sequential(\n",
       "      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (conv8): Sequential(\n",
       "      (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv9): Sequential(\n",
       "      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv10): Sequential(\n",
       "      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (conv11): Sequential(\n",
       "      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv12): Sequential(\n",
       "      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    )\n",
       "    (conv13): Sequential(\n",
       "      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "      (act_conv): ReLU()\n",
       "      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)\n",
       "    )\n",
       "    (avg): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flatten): Flatten()\n",
       "    (linear1): Linear(in_features=512, out_features=4096, bias=True)\n",
       "    (act_linear_1): ReLU()\n",
       "    (bn_linear_1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (linear2): Linear(in_features=4096, out_features=4096, bias=True)\n",
       "    (act_linear_2): ReLU()\n",
       "    (bn_linear_2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (dropout): Dropout(p=0.5, inplace=False)\n",
       "  )\n",
       "  (pred): Linear(in_features=4096, out_features=100, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = MyModel()\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model training code: training components\n",
    "To train the model, we also need to define a loss function, an optimizer, and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_object = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(model.parameters())  # optional arguments such as lr (default 1e-3) can be set here"
   ]
  },
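  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torch.nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) of shape `(N, C)` and integer class indices of shape `(N,)`. Internally it applies `log_softmax` followed by the negative log-likelihood loss, so a model like this one, whose last layer `pred` is a plain `Linear`, should not apply a softmax before the loss. The sketch below checks this equivalence on CPU, with random tensors standing in for a batch of `preds` and `labels`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# stand-ins for one batch: 4 samples, 100 classes\n",
    "fake_logits = torch.randn(4, 100)\n",
    "fake_labels = torch.tensor([3, 17, 42, 99])\n",
    "\n",
    "loss_object = torch.nn.CrossEntropyLoss()\n",
    "ce = loss_object(fake_logits, fake_labels)\n",
    "\n",
    "# the same value computed by hand: log_softmax followed by NLL loss\n",
    "manual = F.nll_loss(F.log_softmax(fake_logits, dim=1), fake_labels)\n",
    "\n",
    "print(ce.item(), manual.item())\n",
    "```\n",
    "\n",
    "With the default `reduction='mean'`, the returned value is the average loss over the samples in the batch."
   ]
  },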
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time  # used to track training and validation speed by hand"
   ]
  },
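  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training a model is usually an iterative process that gets interrupted and restarted, so it is worth saving checkpoints regularly in order to resume training from an intermediate state.\n",
    "Here each checkpoint is saved twice: once as `model_ckpt.pth`, which is overwritten every epoch and always holds the latest state, and once as `model_ckpt-{:04d}.pth` with the epoch number (e.g. `model_ckpt-0007.pth` for epoch 7), which makes it easy to resume interrupted training."
   ]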
  },
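  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "gpu = 1  # index of the GPU to use; set this to 0 on a single-GPU machine\n",
    "\n",
    "model.cuda(gpu)\n",
    "if os.path.exists(\"model_ckpt.pth\"):\n",
    "    # If a checkpoint exists, load it to resume training.\n",
    "    # map_location puts the saved tensors on the GPU selected above.\n",
    "    net_state, optm_state = torch.load(\n",
    "        \"model_ckpt.pth\", map_location=\"cuda:{}\".format(gpu)\n",
    "    )\n",
    "\n",
    "    model.load_state_dict(net_state)\n",
    "    optimizer.load_state_dict(optm_state)\n",
    "\n",
    "    # Always loading the latest checkpoint is a rather crude approach; you can\n",
    "    # also inspect the earlier training log and manually load the per-epoch\n",
    "    # checkpoint with the highest accuracy.\n",
    "    print(\"model loaded\")"
   ]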
  },
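  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The checkpoint above is a `(model_state, optimizer_state)` tuple, so after it is reloaded, the training loop in the next cell still counts epochs from 0. A minimal sketch, assuming you also want to resume the epoch counter, is to save a dict that records the epoch number. The small `torch.nn.Linear` stand-in, the `ckpt_demo.pth` file name and the dict keys are illustrative choices, not part of the original notebook:\n",
    "\n",
    "```python\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "\n",
    "demo_model = torch.nn.Linear(10, 2)  # stand-in for MyModel()\n",
    "demo_opt = torch.optim.Adam(demo_model.parameters())\n",
    "path = os.path.join(tempfile.mkdtemp(), 'ckpt_demo.pth')  # illustrative path\n",
    "\n",
    "# save: keep the epoch number next to the model and optimizer state\n",
    "epoch = 5\n",
    "torch.save(\n",
    "    {\n",
    "        'epoch': epoch,\n",
    "        'model': demo_model.state_dict(),\n",
    "        'optimizer': demo_opt.state_dict(),\n",
    "    },\n",
    "    path,\n",
    ")\n",
    "\n",
    "# resume: map_location='cpu' also works on a machine without the original GPU\n",
    "ckpt = torch.load(path, map_location='cpu')\n",
    "restored_model = torch.nn.Linear(10, 2)\n",
    "restored_model.load_state_dict(ckpt['model'])\n",
    "restored_opt = torch.optim.Adam(restored_model.parameters())\n",
    "restored_opt.load_state_dict(ckpt['optimizer'])\n",
    "start_epoch = ckpt['epoch'] + 1  # continue with the next epoch\n",
    "\n",
    "# the training loop would then iterate over range(start_epoch, EPOCHS)\n",
    "```\n",
    "\n",
    "Saving a dict rather than a tuple also makes it easy to add more fields later, such as the best validation accuracy seen so far."
   ]
  },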
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 Loss 6.379009783063616, Acc 0.96, Val Loss 7.499249600219726, Val Acc 1.06\n",
      "Speed train 234.8145027420376imgs/s val 706.7103467600252imgs/s\n",
      "Epoch 1 Loss 6.250997774396624, Acc 1.1485714285714286, Val Loss 8.049263702392578, Val Acc 1.12\n",
      "Speed train 230.86756209036096imgs/s val 684.4524146021467imgs/s\n",
      "Epoch 2 Loss 6.0144778276715956, Acc 1.16, Val Loss 5.387196655273438, Val Acc 1.22\n",
      "Speed train 226.12469399959375imgs/s val 681.3837266883015imgs/s\n",
      "Epoch 3 Loss 5.589597338867187, Acc 1.0057142857142858, Val Loss 5.907029174804688, Val Acc 1.1\n",
      "Speed train 225.0556244090409imgs/s val 680.1839950846079imgs/s\n",
      "Epoch 4 Loss 5.402270581054688, Acc 1.1714285714285715, Val Loss 5.1295126434326175, Val Acc 1.4\n",
      "Speed train 224.72032368819234imgs/s val 682.6134263746877imgs/s\n",
      "Epoch 5 Loss 5.175169513811384, Acc 1.062857142857143, Val Loss 5.386006506347656, Val Acc 0.78\n",
      "Speed train 224.69030177626445imgs/s val 680.2019310561357imgs/s\n",
      "Epoch 6 Loss 4.945824640328544, Acc 1.2485714285714287, Val Loss 5.106301385498047, Val Acc 1.7\n",
      "Speed train 224.41665937455835imgs/s val 680.7435881896529imgs/s\n",
      "Epoch 7 Loss 4.78519496547154, Acc 1.2714285714285714, Val Loss 5.001887857055664, Val Acc 1.46\n",
      "Speed train 224.3077478615822imgs/s val 679.1943518703032imgs/s\n",
      "Epoch 8 Loss 4.681244001116071, Acc 1.5542857142857143, Val Loss 4.678347979736328, Val Acc 1.76\n",
      "Speed train 224.43681583771664imgs/s val 680.9421894383488imgs/s\n",
      "Epoch 9 Loss 4.594511456734794, Acc 1.977142857142857, Val Loss 4.734268209838867, Val Acc 3.48\n",
      "Speed train 224.43505593143354imgs/s val 679.8263336556521imgs/s\n",
      "Epoch 10 Loss 4.564881538609096, Acc 2.2142857142857144, Val Loss 4.5007019622802735, Val Acc 3.5\n",
      "Speed train 224.35177417457732imgs/s val 681.4680194348261imgs/s\n",
      "Epoch 11 Loss 4.359355766732352, Acc 3.797142857142857, Val Loss 4.303946963500977, Val Acc 5.1\n",
      "Speed train 224.22713261480806imgs/s val 678.4494364538398imgs/s\n",
      "Epoch 12 Loss 4.05738628692627, Acc 6.651428571428571, Val Loss 3.5746582946777345, Val Acc 10.42\n",
      "Speed train 224.18908188445624imgs/s val 679.7794406874162imgs/s\n",
      "Epoch 13 Loss 3.7937849918910436, Acc 10.214285714285714, Val Loss 3.7133444229125976, Val Acc 7.04\n",
      "Speed train 224.10400818300923imgs/s val 679.113765396056imgs/s\n",
      "Epoch 14 Loss 3.2694146046229773, Acc 19.425714285714285, Val Loss 3.6981288192749022, Val Acc 30.78\n",
      "Speed train 224.13154184979035imgs/s val 680.3901503258012imgs/s\n",
      "Epoch 15 Loss 2.7287981418064664, Acc 31.591428571428573, Val Loss 2.7384859634399414, Val Acc 43.0\n",
      "Speed train 224.08093870796063imgs/s val 680.9818793692156imgs/s\n",
      "Epoch 16 Loss 2.4017765145438057, Acc 40.222857142857144, Val Loss 2.373513427734375, Val Acc 55.06\n",
      "Speed train 224.0240886741711imgs/s val 680.9301175413847imgs/s\n",
      "Epoch 17 Loss 1.9575243755885532, Acc 50.81428571428572, Val Loss 1.8042015686035155, Val Acc 60.3\n",
      "Speed train 224.023773128244imgs/s val 679.3334877190623imgs/s\n",
      "Epoch 18 Loss 1.8670056664603096, Acc 52.754285714285714, Val Loss 1.7974752388000488, Val Acc 59.12\n",
      "Speed train 224.04512456698183imgs/s val 677.1390343683371imgs/s\n",
      "Epoch 19 Loss 1.6107693487439836, Acc 58.48571428571429, Val Loss 2.0469212783813475, Val Acc 66.72\n",
      "Speed train 223.87349316639512imgs/s val 678.0275851549168imgs/s\n",
      "Epoch 20 Loss 1.7171708895547049, Acc 56.642857142857146, Val Loss 1.8279149505615235, Val Acc 65.96\n",
      "Speed train 223.90746153613367imgs/s val 677.346826146328imgs/s\n",
      "Epoch 21 Loss 1.2915482904706683, Acc 65.76285714285714, Val Loss 1.3189221771240234, Val Acc 68.62\n",
      "Speed train 223.9696856948392imgs/s val 680.7548137514192imgs/s\n",
      "Epoch 22 Loss 1.1914144684110368, Acc 68.43714285714286, Val Loss 1.0220409889221191, Val Acc 67.86\n",
      "Speed train 223.9880397987904imgs/s val 679.2464221336535imgs/s\n",
      "Epoch 23 Loss 1.0181893185751778, Acc 72.81428571428572, Val Loss 0.6417443874359131, Val Acc 75.94\n",
      "Speed train 223.8553653733427imgs/s val 678.307898091706imgs/s\n",
      "Epoch 24 Loss 0.9370736787523543, Acc 75.12, Val Loss 1.0853789276123047, Val Acc 76.04\n",
      "Speed train 223.88541791918547imgs/s val 681.0767780628761imgs/s\n",
      "Epoch 25 Loss 0.858675898034232, Acc 76.78, Val Loss 0.6966076656341553, Val Acc 76.14\n",
      "Speed train 223.89225800022191imgs/s val 677.8778524201055imgs/s\n",
      "Epoch 26 Loss 0.911681534739903, Acc 75.55142857142857, Val Loss 1.2748067726135255, Val Acc 75.28\n",
      "Speed train 223.8182146384268imgs/s val 678.3320102990313imgs/s\n",
      "Epoch 27 Loss 0.7263422344616481, Acc 80.27428571428571, Val Loss 0.8662283229827881, Val Acc 76.24\n",
      "Speed train 223.84745302957168imgs/s val 677.5476529925469imgs/s\n",
      "Epoch 28 Loss 0.7096801671164377, Acc 80.64, Val Loss 0.4879303056716919, Val Acc 78.4\n",
      "Speed train 223.82907588526868imgs/s val 679.641487715212imgs/s\n",
      "Epoch 29 Loss 0.8400143226759774, Acc 77.24857142857142, Val Loss 0.48099885005950926, Val Acc 76.28\n",
      "Speed train 223.79857240610696imgs/s val 677.2004991093537imgs/s\n",
      "Epoch 30 Loss 0.6340663018226623, Acc 82.75428571428571, Val Loss 0.3814028434753418, Val Acc 77.88\n",
      "Speed train 223.69922632756038imgs/s val 677.8182801677248imgs/s\n",
      "Epoch 31 Loss 0.6143715186391558, Acc 83.12571428571428, Val Loss 1.5937435668945312, Val Acc 55.9\n",
      "Speed train 223.73651624919145imgs/s val 677.9854113416313imgs/s\n",
      "Epoch 32 Loss 0.6921936396871294, Acc 80.92285714285714, Val Loss 0.6802982173919677, Val Acc 74.06\n",
      "Speed train 223.6073489617921imgs/s val 675.2316887830606imgs/s\n",
      "Epoch 33 Loss 0.6144891169275556, Acc 83.18, Val Loss 0.46930033054351805, Val Acc 76.34\n",
      "Speed train 223.57683437615302imgs/s val 677.3064650659536imgs/s\n",
      "Epoch 34 Loss 0.568616727393014, Acc 84.28, Val Loss 0.4940680891036987, Val Acc 79.08\n",
      "Speed train 223.53035806536994imgs/s val 675.1831667819793imgs/s\n",
      "Epoch 35 Loss 0.5646722382409232, Acc 84.30571428571429, Val Loss 0.4494327730178833, Val Acc 79.24\n",
      "Speed train 223.53708898876783imgs/s val 676.0269683297067imgs/s\n",
      "Epoch 36 Loss 0.977967550604684, Acc 74.1, Val Loss 0.8460039363861084, Val Acc 74.68\n",
      "Speed train 223.70501802195884imgs/s val 675.9956110209276imgs/s\n",
      "Epoch 37 Loss 0.7239568670545306, Acc 80.12857142857143, Val Loss 1.048443465423584, Val Acc 80.18\n",
      "Speed train 223.7010776600049imgs/s val 677.2457684330305imgs/s\n",
      "Epoch 38 Loss 0.5576571273531232, Acc 84.37714285714286, Val Loss 0.7641737712860107, Val Acc 78.96\n",
      "Speed train 223.92617365625426imgs/s val 679.0357044899268imgs/s\n",
      "Epoch 39 Loss 0.4953382140840803, Acc 86.30285714285715, Val Loss 0.9396348545074463, Val Acc 80.98\n",
      "Speed train 223.80456065939208imgs/s val 677.9383775081458imgs/s\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 40\n",
    "for epoch in range(EPOCHS):\n",
    "\n",
    "    train_loss = 0.0\n",
    "    train_accuracy = 0\n",
    "    train_samples = 0\n",
    "\n",
    "    val_loss = 0.0\n",
    "    val_accuracy = 0\n",
    "    val_samples = 0\n",
    "\n",
    "    # training phase: dropout on, batch norm updates its running statistics\n",
    "    model.train()\n",
    "    start = time.time()\n",
    "    for imgs, labels in img_train:\n",
    "        imgs = imgs.cuda(gpu)\n",
    "        labels = labels.cuda(gpu)\n",
    "\n",
    "        preds = model(imgs)\n",
    "\n",
    "        loss = loss_object(preds, labels)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        train_samples += imgs.shape[0]\n",
    "\n",
    "        # loss is the mean over the batch; weight it by the batch size so that\n",
    "        # dividing by train_samples gives an exact per-sample average\n",
    "        train_loss += loss.item() * imgs.shape[0]\n",
    "        train_accuracy += (preds.argmax(dim=1) == labels).sum().item()\n",
    "\n",
    "    train_samples_per_second = train_samples / (time.time() - start)\n",
    "\n",
    "    # validation phase: eval mode (dropout off, batch norm uses its running\n",
    "    # statistics) and no gradient tracking, which saves memory and time\n",
    "    model.eval()\n",
    "    start = time.time()\n",
    "    with torch.no_grad():\n",
    "        for imgs, labels in img_val:\n",
    "            imgs = imgs.cuda(gpu)\n",
    "            labels = labels.cuda(gpu)\n",
    "\n",
    "            preds = model(imgs)\n",
    "            # compute the loss on the validation batch itself\n",
    "            loss = loss_object(preds, labels)\n",
    "\n",
    "            val_samples += imgs.shape[0]\n",
    "\n",
    "            val_loss += loss.item() * imgs.shape[0]\n",
    "            val_accuracy += (preds.argmax(dim=1) == labels).sum().item()\n",
    "\n",
    "    val_samples_per_second = val_samples / (time.time() - start)\n",
    "\n",
    "    print(\n",
    "        \"Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}\".format(\n",
    "            epoch,\n",
    "            train_loss / train_samples,\n",
    "            train_accuracy * 100 / train_samples,\n",
    "            val_loss / val_samples,\n",
    "            val_accuracy * 100 / val_samples,\n",
    "        )\n",
    "    )\n",
    "    print(\n",
    "        \"Speed train {}imgs/s val {}imgs/s\".format(\n",
    "            train_samples_per_second, val_samples_per_second\n",
    "        )\n",
    "    )\n",
    "\n",
    "    # Save the model every epoch. Each per-epoch save needs a different file\n",
    "    # name, otherwise it overwrites the previous one. Also keep an eye on disk\n",
    "    # usage so that too many checkpoints do not fill the disk and cause errors.\n",
    "    torch.save((model.state_dict(), optimizer.state_dict()), \"model_ckpt.pth\")\n",
    "    torch.save(\n",
    "        (model.state_dict(), optimizer.state_dict()),\n",
    "        \"model_ckpt-{:04d}.pth\".format(epoch),\n",
    "    )"
   ]
  },
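  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comment in the training loop warns that per-epoch checkpoints can fill up the disk. One possible policy, sketched below under the assumption that files follow the `model_ckpt-{:04d}.pth` naming used there, is to keep only the newest few per-epoch files; `prune_checkpoints` and its `keep` argument are hypothetical names, not an existing API:\n",
    "\n",
    "```python\n",
    "import glob\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def prune_checkpoints(directory='.', keep=3):\n",
    "    # Delete all but the `keep` newest per-epoch files named model_ckpt-XXXX.pth.\n",
    "    # The zero-padded epoch number ({:04d}) makes alphabetical order equal to\n",
    "    # epoch order (up to epoch 9999). model_ckpt.pth does not match the\n",
    "    # pattern, so the latest checkpoint is never deleted.\n",
    "    pattern = os.path.join(glob.escape(directory), 'model_ckpt-[0-9][0-9][0-9][0-9].pth')\n",
    "    files = sorted(glob.glob(pattern))\n",
    "    to_delete = files[:-keep] if keep > 0 else files\n",
    "    for f in to_delete:\n",
    "        os.remove(f)\n",
    "    return to_delete\n",
    "\n",
    "\n",
    "# demo in a throwaway directory with five empty fake checkpoints\n",
    "demo_dir = tempfile.mkdtemp()\n",
    "for e in range(5):\n",
    "    open(os.path.join(demo_dir, 'model_ckpt-{:04d}.pth'.format(e)), 'w').close()\n",
    "open(os.path.join(demo_dir, 'model_ckpt.pth'), 'w').close()\n",
    "\n",
    "removed = prune_checkpoints(demo_dir, keep=3)\n",
    "remaining = sorted(os.listdir(demo_dir))\n",
    "print(remaining)\n",
    "```\n",
    "\n",
    "Calling `prune_checkpoints('.', keep=3)` after the two `torch.save` calls would leave `model_ckpt.pth` plus the three newest per-epoch files. Because this also deletes older checkpoints that may have had the highest validation accuracy, you may want to copy the best one aside before pruning."
   ]
  },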
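  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Some tips\n",
    "\n",
    "The model defined here is fairly large and there is a lot of training data, so each epoch takes a long time. If the code has a bug, discovering it only after a full epoch makes debugging slow.\n",
    "\n",
    "Since the data-generation process here lets us specify the amount of data by hand, we can instead shrink the model and use a smaller dataset to validate the code quickly. In this example, the model can be made smaller by commenting out some of the convolutional and fully connected layers in the model definition, and the dataset can be made smaller by changing the number of samples used in the training loop."
   ]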
  },
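  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to apply this idea, offered as a common practice rather than the method used in this notebook, is to check that a scaled-down model can overfit a single small batch within a few seconds: if the loss does not drop close to zero, the training code probably has a bug. The tiny `torch.nn.Sequential` model and the random single-channel 32×32 images below are stand-ins for `MyModel` and the real data:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# stand-in for a scaled-down MyModel: a single conv layer plus the classifier\n",
    "tiny = torch.nn.Sequential(\n",
    "    torch.nn.Conv2d(1, 8, kernel_size=3, padding=1),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Flatten(),\n",
    "    torch.nn.Linear(8 * 32 * 32, 100),  # 100 classes, as in this competition\n",
    ")\n",
    "tiny_loss = torch.nn.CrossEntropyLoss()\n",
    "tiny_opt = torch.optim.Adam(tiny.parameters())\n",
    "\n",
    "# one small fake batch: 8 single-channel 32x32 images with random labels\n",
    "fake_imgs = torch.randn(8, 1, 32, 32)\n",
    "fake_labels = torch.randint(0, 100, (8,))\n",
    "\n",
    "losses = []\n",
    "for step in range(300):\n",
    "    preds = tiny(fake_imgs)\n",
    "    loss = tiny_loss(preds, fake_labels)\n",
    "    tiny_opt.zero_grad()\n",
    "    loss.backward()\n",
    "    tiny_opt.step()\n",
    "    losses.append(loss.item())\n",
    "\n",
    "# the loss should start roughly at ln(100), about 4.6, and fall close to 0\n",
    "print('first loss {:.3f}, last loss {:.3f}'.format(losses[0], losses[-1]))\n",
    "```\n",
    "\n",
    "With the real data, replacing `img_train` in the training loop by `itertools.islice(img_train, 20)` is one way to cap each epoch at 20 batches while debugging."
   ]
  },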
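  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training converges slowly\n",
    "With a similar network architecture and hyperparameters, a TensorFlow version reaches 90% accuracy in 20 epochs, while this PyTorch version needs 40 epochs to reach about 86% training accuracy (about 81% on validation). Something is probably off; I will keep looking into how to fix it."
   ]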
  },
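  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One possible contributor, offered as a hypothesis rather than a confirmed cause, is that the two frameworks use different defaults. Keras `Conv2D` and `Dense` layers initialize weights with Glorot (Xavier) uniform and biases with zeros, while PyTorch `Conv2d` and `Linear` use a Kaiming-uniform-based scheme. Keras `BatchNormalization` also defaults to `momentum=0.99, epsilon=1e-3`, which corresponds to PyTorch `momentum=0.01, eps=1e-3` because the two libraries define momentum in opposite ways. The sketch below re-initializes a model to mimic these Keras defaults; `keras_like_init` is a hypothetical helper, and the small `Sequential` model only stands in for `MyModel`:\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def keras_like_init(module):\n",
    "    # Glorot/Xavier-uniform weights and zero biases, as in Keras Conv2D/Dense\n",
    "    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):\n",
    "        torch.nn.init.xavier_uniform_(module.weight)\n",
    "        if module.bias is not None:\n",
    "            torch.nn.init.zeros_(module.bias)\n",
    "    # Keras BatchNormalization: momentum=0.99 (PyTorch 0.01), epsilon=1e-3\n",
    "    elif isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):\n",
    "        module.momentum = 0.01\n",
    "        module.eps = 1e-3\n",
    "\n",
    "\n",
    "# small stand-in model; for this notebook it would be model.apply(keras_like_init)\n",
    "demo = torch.nn.Sequential(\n",
    "    torch.nn.Conv2d(1, 16, kernel_size=3, padding=1),\n",
    "    torch.nn.BatchNorm2d(16),\n",
    "    torch.nn.Flatten(),\n",
    "    torch.nn.Linear(16 * 8 * 8, 100),\n",
    ")\n",
    "demo.apply(keras_like_init)  # applies the function to every submodule\n",
    "\n",
    "# Glorot-uniform samples lie within +/- sqrt(6 / (fan_in + fan_out))\n",
    "linear = demo[3]\n",
    "bound = math.sqrt(6.0 / (linear.in_features + linear.out_features))\n",
    "print(linear.weight.abs().max().item(), bound)\n",
    "```\n",
    "\n",
    "Whether matching these defaults closes the gap has not been tested here; the learning rate, preprocessing and data augmentation may matter as much."
   ]
  },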
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
